```{r boilerplate}
library(tidyverse)
library(gbm)
library(policytree)
library(ggrepel)
library(glmnet)
library(cobs)
library(stringr)
source("../code/pitr.R")

knitr::opts_chunk$set(dev = "cairo_pdf")
```


```{r load_data}
# Load the RHC data. Character columns are converted to factors; `trt` is the
# 0/1 indicator that swang1 equals "RHC" and `y` is the 0/1 indicator that
# dth30 is anything other than "Yes".
rhc <- read_csv("rhc.csv") %>%
  mutate(across(where(is.character), as.factor)) %>%
  mutate(
    trt = as.numeric(swang1 == "RHC"),
    y = as.numeric(dth30 != "Yes")
  )

# All covariates used for the nuisance models
covs <- c(
  "age", "sex", "race", "edu", "income", "ninsclas", "cat1",
  "cardiohx", "chfhx", "dementhx", "psychhx", "chrpulhx", "renalhx",
  "liverhx", "gibledhx", "malighx", "immunhx", "transhx", "amihx",
  "das2d3pc", "dnr1", "ca", "surv2md1", "aps1",
  "scoma1", "wtkilo1", "temp1", "meanbp1", "resp1", "hrt1", "pafi1",
  "paco21", "ph1", "wblc1", "hema1",
  "sod1", "pot1", "crea1", "bili1", "alb1",
  "resp", "card", "neuro", "gastr", "renal", "meta", "hema", "seps",
  "trauma", "ortho"
)

# SES covariates, and the remaining covariates once those are excluded
ses_covs <- c("race", "edu", "income", "ninsclas")
health_covs <- covs[!covs %in% ses_covs]

# Design matrix (no intercept) over all covariates
X_all <- model.matrix(~ . - 1, data = select(rhc, all_of(covs)))

# Row-wise count over every design-matrix column whose name contains "Yes"
sum_comp <- rowSums(X_all[, grepl("Yes", colnames(X_all), fixed = TRUE)])

# Health-only design matrix plus sum_comp, each column min-max scaled to [0, 1].
# A column that is constant would give 0 / 0 = NaN here.
X_health <- cbind(
  model.matrix(~ . - 1, data = select(rhc, all_of(health_covs))),
  sum_comp
)
X_health <- apply(X_health, 2, function(col) {
  lo <- min(col)
  hi <- max(col)
  (col - lo) / (hi - lo)
})

n <- nrow(rhc)
```


```{r crossfit_func}

# Cross-fit a gbm model on `tr_dat` and predict on `tr_dat` and `test_dat`.
#
# The training rows are randomly partitioned into 3 folds. For each fold, a gbm
# is fit on the other folds (arguments in `...` are forwarded to gbm()), the
# number of trees is chosen with gbm.perf(), and predictions are made for the
# held-out rows and for every row of `test_dat`.
#
# Arguments:
#   tr_dat   data frame used for fitting
#   test_dat data frame to predict on with every fold model
#   form     model formula passed to gbm()
#   ...      further arguments passed to gbm()
#
# Returns a list with
#   tr:  out-of-fold prediction for each row of `tr_dat`
#   tst: per-row median of the fold-model predictions on `test_dat`
crossfit <- function(tr_dat, test_dat, form, ...) {
  n_tr <- nrow(tr_dat)
  n_folds <- 3
  folds <- split(sample(n_tr), cut(seq_len(n_tr), n_folds))

  oof_pred <- numeric(n_tr)
  tst_by_fold <- matrix(0, nrow = nrow(test_dat), ncol = n_folds)

  for (k in seq_along(folds)) {
    held_out <- folds[[k]]
    fit_rows <- setdiff(seq_len(n_tr), held_out)

    fit <- gbm(form, data = tr_dat[fit_rows, ], ...)
    n_iter <- gbm.perf(fit)

    oof_pred[held_out] <- predict(fit, tr_dat[held_out, ],
                                  n.trees = n_iter, type = "response")
    tst_by_fold[, k] <- predict(fit, test_dat,
                                n.trees = n_iter, type = "response")
  }

  list(tr = oof_pred, tst = apply(tst_by_fold, 1, median))
}


# Cross-fit separate outcome models for the control and treated arms.
#
# The rows of `dat` are randomly partitioned into 3 folds. For each fold, one
# gbm is fit on the control rows (trt == 0) and one on the treated rows
# (trt == 1) of the remaining folds, and both models predict the held-out fold.
#
# Arguments:
#   dat   data frame holding the outcome and covariates referenced by `form`
#   trt   0/1 treatment vector aligned with the rows of `dat`
#   form  outcome model formula passed to gbm()
#   ...   further arguments passed to gbm()
#
# Returns a list with out-of-fold predictions `mu0` (control model) and `mu1`
# (treated model), one value per row of `dat`.
crossfit_outcome <- function(dat, trt, form, ...) {

  n <- nrow(dat)
  # split into 3 folds
  folds <- split(sample(n), cut(seq_len(n), 3))

  pred0 <- numeric(n)
  pred1 <- numeric(n)

  for (fold in folds) {

    tr_idx <- setdiff(seq_len(n), fold)
    # Select each arm among the *training* rows. Indexing trt by the held-out
    # fold here would pick training rows using the wrong units' treatment
    # status.
    tr_idx0 <- tr_idx[trt[tr_idx] == 0]
    tr_idx1 <- tr_idx[trt[tr_idx] == 1]

    gb0 <- gbm(form, data = dat[tr_idx0, ], ...)
    best_iter0 <- gbm.perf(gb0)
    gb1 <- gbm(form, data = dat[tr_idx1, ], ...)
    best_iter1 <- gbm.perf(gb1)

    pred0[fold] <- predict(gb0, dat[fold, ], n.trees = best_iter0,
                           type = "response")
    pred1[fold] <- predict(gb1, dat[fold, ], n.trees = best_iter1,
                           type = "response")
  }

  list(mu0 = pred0, mu1 = pred1)
}


# Cross-fit the arm-specific outcome models and the propensity score together.
#
# The rows of `dat` are randomly partitioned into 3 folds. For each fold, the
# control and treated outcome models and a propensity-score model (trt on all
# columns of `dat` except the outcome) are fit with gbm on the remaining folds
# and used to predict the held-out fold. The same `...` arguments are passed to
# all three gbm() calls.
#
# Arguments:
#   dat       data frame with the outcome and the covariates
#   trt       0/1 treatment vector aligned with the rows of `dat`
#   form_out  outcome model formula; its left-hand side names the outcome
#   ...       further arguments passed to gbm()
#
# Returns a list with out-of-fold predictions `mu0`, `mu1` and `pscore`, plus
# the list of held-out row indices `folds`.
crossfit_both <- function(dat, trt, form_out, ...) {

  n <- nrow(dat)
  # split into 3 folds
  folds <- split(sample(n), cut(seq_len(n), 3))

  pred0 <- numeric(n)
  pred1 <- numeric(n)
  pred_pscore <- numeric(n)

  # Propensity-score data: every column except the outcome, plus treatment.
  # all_of() makes the external name lookup explicit; the direct assignment
  # guarantees the `trt` argument is used even if `dat` has a `trt` column.
  outcome <- as.character(form_out)[2]
  pscore_dat <- dat %>% select(!all_of(outcome))
  pscore_dat[["trt"]] <- trt

  for (fold in folds) {

    tr_idx <- setdiff(seq_len(n), fold)
    gb0 <- gbm(form_out, data = dat[tr_idx[trt[tr_idx] == 0], ], ...)
    best_iter0 <- gbm.perf(gb0)
    gb1 <- gbm(form_out, data = dat[tr_idx[trt[tr_idx] == 1], ], ...)
    best_iter1 <- gbm.perf(gb1)

    gb_pscore <- gbm(trt ~ ., data = pscore_dat[tr_idx, ], ...)
    best_iter <- gbm.perf(gb_pscore)

    pred0[fold] <- predict(gb0, dat[fold, ], n.trees = best_iter0,
                           type = "response")
    pred1[fold] <- predict(gb1, dat[fold, ], n.trees = best_iter1,
                           type = "response")
    pred_pscore[fold] <- predict(gb_pscore, pscore_dat[fold, ],
                                 n.trees = best_iter, type = "response")
  }

  list(mu0 = pred0, mu1 = pred1, pscore = pred_pscore, folds = folds)
}
```


```{r estimate_nuis, cache = TRUE, include = F}

# Cross-fitted nuisance estimates for DR policy learning: arm-specific outcome
# models and the propensity score, all fit with gradient boosting.
nuis_funcs <- crossfit_both(
  select(rhc, all_of(covs), y),
  as.numeric(rhc$swang1 == "RHC"),
  y ~ .,
  distribution = "bernoulli",
  n.trees = 5000,
  shrinkage = 0.01,
  interaction.depth = 1
)

mu0_hat <- nuis_funcs[["mu0"]]
mu1_hat <- nuis_funcs[["mu1"]]
pscore <- nuis_funcs[["pscore"]]

```




```{r dr_treatment_estimate}

# Doubly robust scores for each potential outcome: the outcome-model prediction
# plus an inverse-propensity-weighted residual. The weights are divided by
# their sample mean, so they average to one within each arm's correction term.
gamma_0 <- mu0_hat + (rhc$y - mu0_hat) * (1 - rhc$trt) / (1 - pscore) / mean((1 - rhc$trt) / (1 - pscore))
gamma_1 <-  mu1_hat + (rhc$y - mu1_hat) * rhc$trt / pscore / mean(rhc$trt / pscore)

# Per-patient doubly robust score for the treatment effect
gamma <- gamma_1 - gamma_0

# Point estimate: the average of the effect scores
mean(gamma)

# Standard error: square root of (mean squared deviation of the scores / n)
sqrt(1 / n * mean((gamma - mean(gamma))^2))


```




# Two-dimensional example



```{r fit_plugin_models, include = F}

# The two covariates used in this example, named as in X_health
less_covs <- c("dnr1Yes", "surv2md1")

# Prediction grid: both DNR values crossed with 100 evenly spaced points
# spanning the range of the (min-max scaled) surv2md1 column of X_health.
X_ravel <- expand.grid(
  dnr1Yes = c(0, 1),
  surv2md1 = seq(
    min(X_health[, less_covs[2]]),
    max(X_health[, less_covs[2]]),
    length.out = 100
  )
)

less_dat <- select(rhc, dnr1, surv2md1)

# NOTE(review): the models below are trained on rhc's own dnr1 / surv2md1
# columns, while the grid carries a numeric 0/1 dnr1 and surv2md1 values taken
# from the scaled X_health -- confirm the grid predictions are on the intended
# scale.

# Regress the Y(0) score gamma_0 on dnr1 + surv2md1
m0_less <- crossfit(
  mutate(less_dat, gamma = gamma_0),
  rename(X_ravel, dnr1 = dnr1Yes),
  gamma ~ .,
  n.trees = 500,
  shrinkage = 0.01,
  interaction.depth = 1
)
m0_hat_less <- m0_less[["tr"]]

# Regress the Y(1) score gamma_1 on dnr1 + surv2md1
m1_less <- crossfit(
  mutate(less_dat, gamma = gamma_1),
  rename(X_ravel, dnr1 = dnr1Yes),
  gamma ~ .,
  n.trees = 500,
  shrinkage = 0.01,
  interaction.depth = 1
)
m1_hat_less <- m1_less[["tr"]]

# Plug-in classifiers from the out-of-fold predictions:
# 1 when m1 + m0 - 1 >= 0, and 1 when m1 - m0 >= 0, respectively.
pos_class_less_plugin <- as.numeric(m1_hat_less + m0_hat_less - 1 >= 0)
tau_pos_less_plugin <- as.numeric(m1_hat_less - m0_hat_less >= 0)



```

```{r two_dim_example_threshold, fig.width = 5, fig.height = 4}




# For a given `ub`, search for the best surv2md1 threshold separately in the
# no-DNR and DNR populations.
#
# Each search calls min_regret_1d_threshold() on the group's scaled surv2md1
# column with three coefficient vectors built from `ub` and the plug-in
# classifier pos_class_less_plugin, passing compute_dr_utility and the group's
# treatment, outcome and nuisance estimates.
#
# Returns a two-row data frame (no DNR first, then DNR) whose threshold, sign
# and obj columns hold elements 1, 2 and 3 of each search result.
compute_surv_threshold <- function(ub) {

  # Run the one-dimensional search within the rows flagged by `idx`
  search_group <- function(idx) {
    pos <- pos_class_less_plugin[idx]
    min_regret_1d_threshold(
      X_health[idx, "surv2md1", drop = FALSE],
      -ub + pos * (ub - 1),
      1 + pos * (ub - 1),
      (1 - ub) * pos,
      compute_dr_utility,
      trt = rhc$trt[idx],
      y = rhc$y[idx],
      pscore = pscore[idx],
      mu1 = mu1_hat[idx],
      mu0 = mu0_hat[idx]
    )
  }

  no_dnr <- search_group(X_health[, "dnr1Yes"] == 0)
  dnr <- search_group(X_health[, "dnr1Yes"] == 1)

  data.frame(
    dnr1Yes = c(0, 1),
    threshold = c(no_dnr[1], dnr[1]),
    sign = c(no_dnr[2], dnr[2]),
    obj = c(no_dnr[3], dnr[3]),
    ub = rep(ub, 2)
  )
}



# Grid of ub values and the per-group thresholds found for each of them
ubs_soc <- seq(0.5, 2, length.out = 10)
soc_policies_thresh <- map_df(ubs_soc, compute_surv_threshold)

# Loess-smoothed effect scores against scaled surv2md1, faceted by DNR status,
# with one vertical line per ub at the selected threshold.
# NOTE(review): the rug layer uses the unscaled rhc$surv2md1 and has no dnr1Yes
# column, unlike the smoothed layer -- confirm this is intended.
soc_policies_thresh %>%
  ggplot() +
  geom_smooth(aes(x = surv2md1, y = gamma), color = "black", method = "loess",
             data = rhc %>% mutate(gamma= gamma, dnr1Yes = 1 *(dnr1 == "Yes"),
                surv2md1 = (surv2md1 - min(surv2md1)) / (max(surv2md1) - min(surv2md1))),
             lty = 2, fill = "grey80") +
  geom_hline(yintercept = 0, lty = 2) +
  geom_vline(aes(xintercept = threshold, color = ub)) +
  geom_rug(aes(x = surv2md1, y = 0), color = "black", alpha = 0.01, data = rhc, sides = "b") +
  facet_wrap(~ dnr1Yes, scales = "free_y", labeller = labeller(dnr1Yes = c("0"= "No DNR", "1"= "DNR"))) +
  xlab("Estimated probability of surviving 2 months") +
  ylab("Smoothed Treatment Effect") +
  scale_color_viridis_c(expression(u[l])) +
  # The only mapped scale is colour, so the colorbar settings must target it;
  # a `fill` guide has nothing to attach to and is ignored.
  guides(color = guide_colorbar(title.position = "top", barheight = 0.5, barwidth = 10)) +
  theme_bw() +
  theme(legend.position = "bottom")


```

```{r two_dim_example_pos_class_plugin, fig.width = 4.5, fig.height = 3}

# Plot the sum of the two fitted regressions over the prediction grid, by DNR
# status. The dashed line at 1 is the cut-off used by the plug-in classifier
# (m1 + m0 - 1 >= 0), so the axis label shows the sum rather than a difference.
X_ravel %>%
  mutate(pred = m1_less$tst + m0_less$tst) %>%
  ggplot(aes(x = surv2md1, y = pred)) +
  geom_hline(yintercept = 1, lty = 2) +
  geom_line() +
  facet_wrap(~ dnr1Yes, scales = "free_y", labeller = labeller(dnr1Yes = c("0"= "No DNR", "1"= "DNR"))) +
  xlab("Estimated probability of surviving 2 months") +
  ylab(expression(m(1,x) + m(0,x))) +
  theme_bw()
```


```{r pharm_twodim, fig.width = 4.5, fig.height = 3}

# Smoothed upper bounds, by DNR status and scaled surv2md1. The bound columns
# are built from the DR scores and the plug-in classifiers; only the columns
# whose name does not contain "low" are plotted.
rhc %>%
  mutate(
    gamma1 = gamma_1,
    gamma0 = gamma_0,
    dnr1Yes = as.numeric(dnr1 == "Yes"),
    surv2md1 = (surv2md1 - min(surv2md1)) / (max(surv2md1) - min(surv2md1))
  ) %>%
  mutate(
    pos_class = as.numeric(m0_less$tr + m1_less$tr >= 1),
    harm_bound_up = gamma0 + pos_class * (1 - gamma1 - gamma0),
    harm_bound_low = (m1_less$tr <= m0_less$tr) * (gamma1 - gamma0),
    help_bound = gamma1 + pos_class * (1 - gamma1 - gamma0),
    help_bound_low = (m1_less$tr >= m0_less$tr) * (gamma1 - gamma0)
  ) %>%
  select(dnr1Yes, surv2md1, contains("bound")) %>%
  pivot_longer(-c(dnr1Yes, surv2md1)) %>%
  mutate(
    low = str_detect(name, "low"),
    harm = ifelse(str_detect(name, "harm"),
                  "Proportion harmed", "Proportion helped")
  ) %>%
  filter(!low) %>%
  ggplot(aes(x = surv2md1, y = value)) +
  geom_smooth(color = "black", fill = "grey50", method = "loess") +
  coord_cartesian(ylim = c(0, 1)) +
  facet_grid(
    harm ~ dnr1Yes,
    scales = "free_y",
    labeller = labeller(dnr1Yes = c("0" = "No DNR", "1" = "DNR")),
    switch = "y"
  ) +
  xlab("Estimated probability of surviving 2 months") +
  scale_y_continuous("Upper bound", labels = scales::percent) +
  theme_bw()
```


```{r two_dim_example_oracle_threshold, fig.width = 5, fig.height = 4}




# For a given `ub`, search for the best surv2md1 threshold separately in the
# no-DNR and DNR populations, using plug-in estimates for every nuisance
# classifier.
#
# The coefficient vectors come from get_regret_oracle(), called with the two
# plug-in classifiers, a ub-dependent plug-in rule, the constant 1 and `ub`.
#
# Returns a two-row data frame (no DNR first, then DNR) whose threshold, sign
# and obj columns hold elements 1, 2 and 3 of each search result.
compute_surv_threshold_oracle <- function(ub) {

  # Plug-in rule; which comparison applies depends on pos_class_less_plugin
  rule_when_neg <- as.numeric(m1_hat_less >= ub * m0_hat_less)
  rule_when_pos <- as.numeric(
    m1_hat_less >= 1 / ub * m0_hat_less + (ub - 1) / ub
  )
  soc_class_plugin <- ifelse(pos_class_less_plugin == 0,
                             rule_when_neg, rule_when_pos)

  utils <- get_regret_oracle(tau_pos_less_plugin, pos_class_less_plugin,
                             soc_class_plugin, 1, ub)

  # Run the one-dimensional search within the rows flagged by `idx`
  search_group <- function(idx) {
    min_regret_1d_threshold(
      X_health[idx, "surv2md1", drop = FALSE],
      utils$c0[idx],
      utils$c1[idx],
      utils$c_const[idx],
      compute_dr_utility,
      trt = rhc$trt[idx],
      y = rhc$y[idx],
      pscore = pscore[idx],
      mu1 = mu1_hat[idx],
      mu0 = mu0_hat[idx]
    )
  }

  no_dnr <- search_group(X_health[, "dnr1Yes"] == 0)
  dnr <- search_group(X_health[, "dnr1Yes"] == 1)

  data.frame(
    dnr1Yes = c(0, 1),
    threshold = c(no_dnr[1], dnr[1]),
    sign = c(no_dnr[2], dnr[2]),
    obj = c(no_dnr[3], dnr[3]),
    ub = rep(ub, 2)
  )
}



# Per-group thresholds from the oracle-regret search, one pair per ub value
oracle_policies_thresh <- map_df(ubs_soc, compute_surv_threshold_oracle)

# Loess-smoothed effect scores against scaled surv2md1, faceted by DNR status,
# with one vertical line per ub at the selected threshold.
# NOTE(review): the rug layer uses the unscaled rhc$surv2md1 and has no dnr1Yes
# column, unlike the smoothed layer -- confirm this is intended.
oracle_policies_thresh %>%
  ggplot() +
  geom_smooth(aes(x = surv2md1, y = gamma), color = "black", method = "loess",
             data = rhc %>% mutate(gamma= gamma, dnr1Yes = 1 *(dnr1 == "Yes"),
                surv2md1 = (surv2md1 - min(surv2md1)) / (max(surv2md1) - min(surv2md1))),
             lty = 2, fill = "grey80") +
  geom_hline(yintercept = 0, lty = 2) +
  geom_vline(aes(xintercept = threshold, color = ub)) +
  geom_rug(aes(x = surv2md1, y = 0), color = "black", alpha = 0.01, data = rhc, sides = "b") +
  facet_wrap(~ dnr1Yes, scales = "free_y", labeller = labeller(dnr1Yes = c("0"= "No DNR", "1"= "DNR"))) +
  xlab("Estimated probability of surviving 2 months") +
  ylab("Smoothed Treatment Effect") +
  scale_color_viridis_c(expression(u[l])) +
  # The only mapped scale is colour, so the colorbar settings must target it;
  # a `fill` guide has nothing to attach to and is ignored.
  guides(color = guide_colorbar(title.position = "top", barheight = 0.5, barwidth = 10)) +
  theme_bw() +
  theme(legend.position = "bottom")



```



# Using multiple covariates

```{r cate_variable_importance, cache = T}

# Variable importance for the effect: fit a regression forest of the DR effect
# scores (gamma_1 - gamma_0) on X_health and extract grf's importance measure.
rf <- grf::regression_forest(X_health, gamma_1 - gamma_0)
importances <- grf::variable_importance(rf)

# Factor levels follow increasing importance so the dot plot comes out sorted
import_df <- data.frame(
  name = factor(
    colnames(X_health),
    levels = colnames(X_health)[order(importances)]
  ),
  importance = importances
)

ggplot(import_df, aes(x = importance, y = name)) +
  geom_point() +
  labs(x = "Variable Importance Measure", y = "") +
  theme_bw()

# Keep the 10 columns of X_health with the largest importance
reduced_covs <- colnames(X_health)[
  order(importances, decreasing = TRUE)[seq_len(10)]
]
X_health_reduced <- X_health[, reduced_covs]


```


```{r predict_ge1_plugin, include = F, cache = T}

# NOTE(review): reduced_covs holds column names of the model matrix X_health.
# all_of() errors if any of them (for example a dummy-coded factor level, or
# sum_comp) is not a column of rhc -- confirm the selected names exist in rhc.
health_dat <- select(rhc, all_of(reduced_covs))

# Cross-fitted regression of the Y(0) score on the reduced covariates
m0_health <- crossfit(
  mutate(health_dat, gamma = gamma_0),
  health_dat,
  gamma ~ .,
  n.trees = 500,
  shrinkage = 0.01,
  interaction.depth = 1
)
m0_hat_health <- m0_health[["tr"]]

# Cross-fitted regression of the Y(1) score on the reduced covariates
m1_health <- crossfit(
  mutate(health_dat, gamma = gamma_1),
  health_dat,
  gamma ~ .,
  n.trees = 500,
  shrinkage = 0.01,
  interaction.depth = 1
)
m1_hat_health <- m1_health[["tr"]]

# Plug-in classifiers from the out-of-fold predictions:
# 1 when m1 + m0 - 1 >= 0, and 1 when m1 - m0 >= 0, respectively.
pos_class_plugin <- as.numeric(m1_hat_health + m0_hat_health - 1 >= 0)
tau_pos_plugin <- as.numeric(m1_hat_health - m0_hat_health >= 0)

```


```{r regret_helper_funcs}

# Row 1 is set aside as the validation index; all other rows form tr_idx.
val_idx <- 1
tr_idx <- setdiff(seq_len(n), val_idx)



# Fit a policy tree for a given `ub`.
#
# Coefficient vectors built from `ub`, the constant ug = 1 and the plug-in
# classifier pos_class_plugin are passed to compute_dr_utility() together with
# the treatment, outcome and nuisance estimates. The result is used as the
# reward of the second action (the first action has reward 0) in a depth-3
# hybrid policy tree on X_health_reduced.
#
# Returns the fitted policytree::hybrid_policy_tree object.
min_regret_soc_tree <- function(ub) {
  ug <- 1

  dr_util <- compute_dr_utility(
    -ub + pos_class_plugin * (ub - ug),
    ug + pos_class_plugin * (ub - ug),
    (ug - ub) * pos_class_plugin,
    trt = rhc$trt,
    y = rhc$y,
    pscore = pscore,
    mu1 = mu1_hat,
    mu0 = mu0_hat
  )

  policytree::hybrid_policy_tree(
    X_health_reduced,
    cbind(0, dr_util),
    depth = 3,
    search.depth = 2,
    split.step = 10
  )
}

```



```{r regret_relative_to_standard_of_care_tree, cache = T}

# Grid of ub values: 21 points on [0.5, 1] followed by 19 points on [1.025, 2].
# NOTE(review): the spacing is 0.025 up to 1 but about 0.054 above it --
# confirm the coarser upper grid is intended.
ubs_tree <- c(
  seq(0.5, 1, length.out = 21),
  seq(1.025, 2, length.out = 19)
)

# One policy tree per ub value
minmax_tree <- map(ubs_tree, min_regret_soc_tree)

# Tree for the symmetric case ub = 1
symm_policy_tree <- min_regret_soc_tree(1)
```








```{r oracle_regret_helper_funcs}


# Plug-in treatment rule for a given `ub`, as a 0/1 vector.
# Where pos_class_plugin is 0 the rule treats when m1 >= ub * m0; elsewhere it
# treats when m1 >= m0 / ub + (ub - 1) / ub.
min_regret_soc_plugin <- function(ub) {
  rule_when_neg <- as.numeric(m1_hat_health >= ub * m0_hat_health)
  rule_when_pos <- as.numeric(
    m1_hat_health >= 1 / ub * m0_hat_health + (ub - 1) / ub
  )
  ifelse(pos_class_plugin == 0, rule_when_neg, rule_when_pos)
}

# Fit a policy tree for a given `ub` using coefficients from
# get_regret_oracle().
#
# get_regret_oracle() is called with the two plug-in classifiers, the plug-in
# rule for `ub`, the constant ug = 1 and `ub`; its c0, c1 and c_const elements
# are passed to compute_dr_utility(). The result is used as the reward of the
# second action (the first action has reward 0) in a depth-3 hybrid policy tree
# on X_health_reduced.
#
# Returns the fitted policytree::hybrid_policy_tree object.
min_regret_oracle_tree <- function(ub) {
  ug <- 1

  coefs <- get_regret_oracle(
    tau_pos_plugin,
    pos_class_plugin,
    min_regret_soc_plugin(ub),
    ug,
    ub
  )

  dr_util <- compute_dr_utility(
    coefs$c0, coefs$c1, coefs$c_const,
    trt = rhc$trt,
    y = rhc$y,
    pscore = pscore,
    mu1 = mu1_hat,
    mu0 = mu0_hat
  )

  policytree::hybrid_policy_tree(
    X_health_reduced,
    cbind(0, dr_util),
    depth = 3,
    search.depth = 2,
    split.step = 10
  )
}


```



```{r regret_relative_to_oracle_tree, cache = T}

# One oracle-regret policy tree per ub value in ubs_tree
minmax_oracle_tree <- map(ubs_tree, min_regret_oracle_tree)

```



```{r compute_soc_plugin, include = F}



# Plug-in treatment rule for a given `ub` (the utility difference for bad
# outcomes), as a 0/1 vector.
# Where pos_class_plugin is 0 the rule treats when m1 >= ub * m0; elsewhere it
# treats when m1 >= m0 / ub + (ub - 1) / ub.
min_regret_soc_plugin <- function(ub) {
  rule_when_neg <- as.numeric(m1_hat_health >= ub * m0_hat_health)
  rule_when_pos <- as.numeric(
    m1_hat_health >= 1 / ub * m0_hat_health + (ub - 1) / ub
  )
  ifelse(pos_class_plugin == 0, rule_when_neg, rule_when_pos)
}

# Summary metrics of the plug-in rule, one row per ub value in ubs_tree:
# share treated, the two compute_dr_prob_* quantities and the DR average
# outcome.
metrics_plugin <- lapply(seq_along(ubs_tree), function(i) {
  ug <- 1
  ub <- ubs_tree[i]

  # NOTE(review): this value is never used below; candidate for removal if
  # compute_dr_utility() has no side effects.
  util <- compute_dr_utility(
    -ub - pos_class_plugin * (ug - ub),
    ug + pos_class_plugin * (ub - ug),
    (ug - ub) * pos_class_plugin,
    trt = rhc$trt, y = rhc$y,
    pscore = pscore,
    mu1 = mu1_hat, mu0 = mu0_hat
  )

  trts <- min_regret_soc_plugin(ub)

  data.frame(
    pct_trt = mean(trts),
    prob01 = compute_dr_prob_01(rhc$y, rhc$trt, pscore, mu1_hat,
                                mu0_hat, pos_class_plugin, trts),
    prob10 = compute_dr_prob_10(rhc$y, rhc$trt, pscore, mu1_hat,
                                mu0_hat, pos_class_plugin, trts),
    avg_outcome = compute_dr_avg_outcome(rhc$y, rhc$trt, pscore, mu1_hat,
                                         mu0_hat, trts)
  )
}) %>%
  bind_rows() %>%
  mutate(ub = ubs_tree)



# Summary metrics of the fitted policy trees, one row per ub value in ubs_tree.
# Tree predictions are shifted by -1 to give 0/1 treatment assignments.
metrics_tree <- lapply(seq_along(ubs_tree), function(i) {
  trts <- predict(minmax_tree[[i]], X_health_reduced) - 1

  data.frame(
    pct_trt = mean(trts),
    prob01 = compute_dr_prob_01(rhc$y, rhc$trt, pscore, mu1_hat,
                                mu0_hat, pos_class_plugin, trts),
    prob10 = compute_dr_prob_10(rhc$y, rhc$trt, pscore, mu1_hat,
                                mu0_hat, pos_class_plugin, trts),
    prob01_low = compute_dr_prob_01_low(rhc$y, rhc$trt, pscore, mu1_hat,
                                        mu0_hat, tau_pos_plugin, trts),
    prob10_low = compute_dr_prob_10_low(rhc$y, rhc$trt, pscore, mu1_hat,
                                        mu0_hat, tau_pos_plugin, trts),
    avg_outcome = compute_dr_avg_outcome(rhc$y, rhc$trt, pscore, mu1_hat,
                                         mu0_hat, trts)
  )
}) %>%
  bind_rows() %>%
  mutate(ub = ubs_tree)



# Summary metrics of the oracle-regret policy trees, one row per ub value in
# ubs_tree. Tree predictions are shifted by -1 to give 0/1 assignments.
metrics_oracle_tree <- lapply(seq_along(ubs_tree), function(i) {
  trts <- predict(minmax_oracle_tree[[i]], X_health_reduced) - 1

  data.frame(
    pct_trt = mean(trts),
    prob01 = compute_dr_prob_01(rhc$y, rhc$trt, pscore, mu1_hat,
                                mu0_hat, pos_class_plugin, trts),
    prob10 = compute_dr_prob_10(rhc$y, rhc$trt, pscore, mu1_hat,
                                mu0_hat, pos_class_plugin, trts),
    avg_outcome = compute_dr_avg_outcome(rhc$y, rhc$trt, pscore, mu1_hat,
                                         mu0_hat, trts)
  )
}) %>%
  bind_rows() %>%
  mutate(ub = ubs_tree)


```

```{r pct_harmful_v_helpful_tree, fig.width = 7, fig.height = 4.5}

# Convex regression of prob10 on prob01 across the tree policies
frontier <- conreg(metrics_tree$prob01, metrics_tree$prob10, convex = TRUE)

# Treatment assignments and metrics of the symmetric (ub = 1) policy tree.
# Tree predictions are shifted by -1 to give 0/1 assignments.
symm_trt_tree <- predict(symm_policy_tree, X_health_reduced) - 1

prob01_symm_tree <- compute_dr_prob_01(rhc$y, rhc$trt, pscore, mu1_hat,
                              mu0_hat, pos_class_plugin, symm_trt_tree)
prob10_symm_tree <- compute_dr_prob_10(rhc$y, rhc$trt, pscore, mu1_hat,
                              mu0_hat, pos_class_plugin, symm_trt_tree)
avg_outcome_symm_tree <- compute_dr_avg_outcome(rhc$y, rhc$trt, pscore, mu1_hat,
                                        mu0_hat, symm_trt_tree)

# prob10 against prob01 for the trees with ub in [0.5, 1.5]; negative estimates
# are clipped at 0 and the ub = 1 tree is highlighted in red. No colour
# aesthetic is mapped, so no colour scale is added, and the red highlight layer
# appears once.
p <- metrics_tree %>%
  filter(ub >= 0.5, ub <= 1.5) %>%
  mutate(prob01 = pmax(prob01, 0),
         prob10 = pmax(prob10, 0)) %>%
  ggplot(aes(x = prob01, y = prob10)) +
  geom_line() +
  geom_point(color = "grey70") +
  # geom_line(aes(y = predict(frontier, metrics_tree$prob01))) +
  geom_point(color = "red",
    data = . %>% filter(ub == 1)
  ) +
  scale_x_continuous("Worst-case proportion given harmful treatment",
                     labels = scales::percent) +
  scale_y_continuous("Worst-case proportion failed to give useful treatment",
                    labels = scales::percent) +
  theme_bw() +
  theme(strip.background = element_blank(),
        strip.placement = "outside")
p
```


```{r pct_harmful_v_helpful_tree_square, fig.width = 4.5, fig.height = 4.5}

# Same plot with both axes restricted to [0, 0.3]; these scales replace the
# ones already attached to `p`. `limits` is spelled out in full rather than
# relying on partial argument matching.
p +
  scale_x_continuous("Worst-case proportion given harmful treatment",
                     labels = scales::percent, limits = c(0, .3)) +
  scale_y_continuous("Worst-case proportion failed to give useful treatment",
                     labels = scales::percent, limits = c(0, .3))


```

```{r,  pct_treated_soc_tree, fig.width = 4.5, fig.height = 4.5}

# Share of patients assigned treatment by each tree, against ub. The ub = 1
# tree is highlighted in red and a smoother is drawn through the points.
ggplot(metrics_tree, aes(x = ub, y = pct_trt)) +
  geom_vline(xintercept = 1, lty = 2) +
  geom_point(color = "grey70") +
  geom_point(color = "red", data = function(d) filter(d, ub == 1)) +
  geom_smooth(color = "black", se = FALSE, span = .35) +
  ylab("% of patients assigned RHC") +
  scale_x_continuous(expression(u[l])) +
  scale_y_continuous(
    labels = scales::percent,
    limits = c(0, 1),
    oob = scales::squish
  ) +
  theme_bw()

```

```{r,  pct_treated_oracle_tree, fig.width = 4.5, fig.height = 4.5}

# Share of patients assigned treatment by each oracle-regret tree, against ub.
# The ub = 1 tree is highlighted in red and a smoother is drawn through the
# points.
ggplot(metrics_oracle_tree, aes(x = ub, y = pct_trt)) +
  geom_vline(xintercept = 1, lty = 2) +
  geom_point(color = "grey70") +
  geom_point(color = "red", data = function(d) filter(d, ub == 1)) +
  geom_smooth(color = "black", se = FALSE, span = .35) +
  ylab("% of patients assigned RHC") +
  scale_x_continuous(expression(u[l])) +
  scale_y_continuous(
    labels = scales::percent,
    limits = c(0, 1),
    oob = scales::squish
  ) +
  theme_bw()

```


```{r regret0_vs_pct_harmful, fig.width = 7, fig.height = 4.5}




# One minus the DR average outcome against a worst-case proportion, in two
# panels: trees with ub >= 1 are plotted against prob01, the others against
# prob10. Rows with a negative proportion are dropped. The ub = 1 tree is
# highlighted in red in both panels. No colour aesthetic is mapped, so no
# colour scale is added.
metrics_tree %>%
  mutate(constr = ifelse(ub >= 1, prob01, prob10),
        fct = ifelse(ub >= 1, "Give harmful treatment", "Fail to give useful treatment")) %>%
  filter(constr >= 0)%>%
  ggplot(aes(x = constr, y = 1 - avg_outcome)) +
  geom_line() +
  geom_point(color = "grey70") +
  geom_point(color = "red", data = . %>% filter(ub == 1)) +
  # Repeat the ub = 1 point in the other panel, positioned by its prob10
  geom_point(color = "red",
             data = . %>% filter(ub == 1) %>% mutate(fct = "Fail to give useful treatment", constr = prob10)) +
  facet_wrap(~ fct, scales = "free_x", strip.position = "bottom") +
  scale_x_continuous("Worst-case proportion", labels = scales::percent) +
  scale_y_continuous("Expected mortality", labels = scales::percent) +
  theme_bw() +
  theme(strip.background = element_blank(),
        strip.placement = "outside")
```

```{r pct_harmful_v_helpful_plugin, fig.width = 7, fig.height = 4.5}

# Convex regression of prob10 on prob01 across the plug-in rules
frontier <- conreg(metrics_plugin$prob01, metrics_plugin$prob10, convex = TRUE)

# Symmetric plug-in rule: treat when the fitted m1 is at least the fitted m0.
# This is a single assignment built only from objects defined in this file.
symm_trt_plugin <- m1_hat_health - m0_hat_health >= 0

prob01_symm_plugin <- compute_dr_prob_01(rhc$y, rhc$trt, pscore, mu1_hat,
                              mu0_hat, pos_class_plugin, symm_trt_plugin)
prob10_symm_plugin <- compute_dr_prob_10(rhc$y, rhc$trt, pscore, mu1_hat,
                              mu0_hat, pos_class_plugin, symm_trt_plugin)
avg_outcome_symm_plugin <- compute_dr_avg_outcome(rhc$y, rhc$trt, pscore, mu1_hat,
                                        mu0_hat, symm_trt_plugin)

# prob10 against prob01 for the plug-in rules, coloured by ub, with the fitted
# convex curve. The symmetric rule and the ub = 1 rule are marked in red.
metrics_plugin %>%
  ggplot(aes(x = prob01, y = prob10)) +
  geom_point(aes(color = ub)) +
  scale_color_distiller(type = "div") +
  geom_line(aes(y = predict(frontier, metrics_plugin$prob01))) +
  geom_point(color = "red",
    data = data.frame(prob01 = prob01_symm_plugin, prob10 = prob10_symm_plugin)
  ) +
  geom_point(color = "red",
    data = . %>% filter(ub == 1)
  ) +
  scale_x_continuous("Worst-case proportion given harmful treatment", labels = scales::percent) +
  scale_y_continuous("Worst-case proportion failed to give useful treatment", labels = scales::percent) +
  theme_bw() +
  theme(strip.background = element_blank(),
        strip.placement = "outside")
```

```{r,  pct_treated_soc_plugin, fig.width = 4.5, fig.height = 4.5}

# Share of patients assigned treatment by each plug-in rule, against ub. The
# ub = 1 rule is highlighted in red and a smoother is drawn through the points.
ggplot(metrics_plugin, aes(x = ub, y = pct_trt)) +
  geom_vline(xintercept = 1, lty = 2) +
  geom_point(color = "grey70") +
  geom_point(color = "red", data = function(d) filter(d, ub == 1)) +
  geom_smooth(color = "black", se = FALSE, span = .35) +
  ylab("% of patients assigned RHC") +
  scale_x_continuous(expression(u[l])) +
  scale_y_continuous(
    labels = scales::percent,
    limits = c(0, 1),
    oob = scales::squish
  ) +
  theme_bw()

```
